#  Copyright 2025 Collate
#  Licensed under the Collate Community License, Version 1.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#  https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

"""
Validator for table custom SQL Query test case
"""

import traceback
from abc import abstractmethod
from enum import Enum
from typing import Optional, cast

from metadata.data_quality.validations.base_test_handler import BaseTestValidator
from metadata.generated.schema.tests.basic import (
    TestCaseResult,
    TestCaseStatus,
    TestResultValue,
)
from metadata.utils.helpers import evaluate_threshold
from metadata.utils.logger import test_suite_logger

logger = test_suite_logger()

# Name of the TestResultValue that reports how many rows the test query
# returned; its value is None when the query raised and the test was aborted.
RESULT_ROW_COUNT = "resultRowCount"


class Strategy(Enum):
    """How the custom SQL query result is obtained.

    Read from the ``strategy`` test parameter and passed to ``_run_results``;
    ``ROWS`` is the default strategy in the ``_run_results`` signature.
    """

    # NOTE(review): presumably the query result is consumed as a single count
    # - confirm against the concrete ``_run_results`` implementations.
    COUNT = "COUNT"
    # NOTE(review): presumably the query returns rows that are counted with
    # len() - confirm against the concrete ``_run_results`` implementations.
    ROWS = "ROWS"


class BaseTableCustomSQLQueryValidator(BaseTestValidator):
    """Validator for the table custom SQL query test case.

    Runs a user-supplied SQL expression, compares the number of rows it
    returns against ``threshold`` using ``operator`` and, when
    ``computePassedFailedRowCount`` is enabled on the test case, derives
    passed/failed row counts relative to the table's total row count.

    Subclasses must implement ``_run_results`` and ``compute_row_count``.
    """

    def _run_validation(self) -> TestCaseResult:
        """Execute the custom SQL query test and build its result.

        Reads the ``sqlExpression``, ``operator`` (default ``"<="``),
        ``threshold`` (default ``0``) and ``strategy`` test parameters, runs
        the query through ``_run_results`` and evaluates the returned row
        count against the threshold.

        Returns:
            TestCaseResult: ``Aborted`` when running the query raised,
            otherwise ``Success`` or ``Failed`` depending on the threshold
            comparison.
        """
        param_values = self.test_case.parameterValues  # type: ignore

        sql_expression = cast(
            str,
            self.get_test_case_param_value(param_values, "sqlExpression", str),
        )
        operator = cast(
            str,
            self.get_test_case_param_value(param_values, "operator", str, "<="),
        )
        threshold = cast(
            int,
            self.get_test_case_param_value(param_values, "threshold", int, default=0),
        )
        strategy = cast(
            Strategy,
            self.get_test_case_param_value(param_values, "strategy", Strategy),
        )

        try:
            rows = self._run_results(sql_expression, strategy)
        except Exception as exc:
            # Any failure while running the user query aborts the test rather
            # than failing it: there is no row count to compare.
            msg = f"Error computing {self.test_case.fullyQualifiedName}: {exc}"  # type: ignore
            logger.debug(traceback.format_exc())
            logger.warning(msg)
            return self.get_test_case_result_object(
                self.execution_date,
                TestCaseStatus.Aborted,
                msg,
                [TestResultValue(name=RESULT_ROW_COUNT, value=None)],
            )

        # _run_results may return either an integer count directly or a sized
        # collection of rows.
        len_rows = rows if isinstance(rows, int) else len(rows)
        test_passed = evaluate_threshold(threshold, operator, len_rows)
        status = TestCaseStatus.Success if test_passed else TestCaseStatus.Failed

        row_count = None
        passed_rows = None
        failed_rows = None
        if self.test_case.computePassedFailedRowCount:
            row_count = self._get_total_row_count_if_needed()
            passed_rows, failed_rows = self._calculate_passed_failed_rows(
                test_passed, operator, threshold, len_rows, row_count
            )

        return self.get_test_case_result_object(
            self.execution_date,
            status,
            f"Found {len_rows} row(s). Test query is expected to return {operator} {threshold} row(s).",
            [TestResultValue(name=RESULT_ROW_COUNT, value=str(len_rows))],
            row_count=row_count,
            failed_rows=failed_rows,
            passed_rows=passed_rows,
        )

    @abstractmethod
    def _run_results(self, sql_expression: str, strategy: Strategy = Strategy.ROWS):
        """Run ``sql_expression`` and return its result.

        Implementations may return either an ``int`` (used directly as the
        row count) or a sized collection of rows whose ``len`` is used.
        """
        raise NotImplementedError

    @abstractmethod
    def compute_row_count(self):
        """Compute the total number of rows in the table.

        Raises:
            NotImplementedError: must be implemented by subclasses.
        """
        raise NotImplementedError

    def get_row_count(self) -> int:
        """Return the table's total row count.

        Returns:
            int: the value produced by ``compute_row_count``.
        """
        return self.compute_row_count()

    def _get_total_row_count_if_needed(self) -> int:
        """Get total row count if computePassedFailedRowCount is enabled"""
        return self.get_row_count()

    @staticmethod
    def _remaining_rows(row_count: Optional[int], counted_rows: int) -> int:
        """Return ``row_count - counted_rows``, or 0 when ``row_count`` is unknown.

        A ``None`` (or ``0``) total cannot be split further, so the remainder
        is reported as 0 instead of raising ``TypeError`` on ``None``.
        """
        return (row_count - counted_rows) if row_count else 0

    def _calculate_passed_failed_rows(
        self,
        test_passed: bool,
        operator: str,
        threshold: int,
        len_rows: int,
        row_count: Optional[int],
    ) -> tuple[int, int]:
        """Calculate passed and failed rows based on test result and operator

        Args:
            test_passed: Whether the test passed
            operator: Comparison operator (>, >=, <, <=, ==)
            threshold: Expected threshold value
            len_rows: Number of rows returned by the test query
            row_count: Total number of rows in the table (or None)

        Returns:
            Tuple of (passed_rows, failed_rows), both clamped to be >= 0
        """
        if test_passed:
            return self._calculate_passed_rows_success(operator, len_rows, row_count)
        return self._calculate_passed_rows_failure(
            operator, threshold, len_rows, row_count
        )

    def _calculate_passed_rows_success(
        self, operator: str, len_rows: int, row_count: Optional[int]
    ) -> tuple[int, int]:
        """Calculate passed/failed rows when the test passed.

        For ``<``/``<=`` the rows returned by the query are counted as failed
        and the rest of the table as passed; for ``>``/``>=`` and ``==`` the
        returned rows are counted as passed and the rest as failed. Unknown
        operators report every returned row as passed.
        """
        remainder = self._remaining_rows(row_count, len_rows)
        if operator in (">", ">=", "=="):
            passed_rows, failed_rows = len_rows, remainder
        elif operator in ("<", "<="):
            passed_rows, failed_rows = remainder, len_rows
        else:
            passed_rows, failed_rows = len_rows, 0

        return max(0, passed_rows), max(0, failed_rows)

    def _calculate_passed_rows_failure(
        self, operator: str, threshold: int, len_rows: int, row_count: Optional[int]
    ) -> tuple[int, int]:
        """Calculate passed/failed rows when the test failed"""
        if operator in (">", ">="):
            return self._calculate_greater_than_failure(len_rows, row_count)
        if operator in ("<", "<="):
            return self._calculate_less_than_failure(len_rows, row_count)
        if operator == "==":
            return self._calculate_equal_failure(threshold, len_rows, row_count)

        # Unknown operator: treat the whole table (or, without a total, every
        # returned row) as failed.
        failed_rows = row_count if row_count else len_rows
        return 0, max(0, failed_rows)

    def _calculate_greater_than_failure(
        self, len_rows: int, row_count: Optional[int]
    ) -> tuple[int, int]:
        """Calculate rows for > or >= operator failure (expected more rows)"""
        passed_rows = len_rows
        failed_rows = self._remaining_rows(row_count, len_rows)
        return max(0, passed_rows), max(0, failed_rows)

    def _calculate_less_than_failure(
        self, len_rows: int, row_count: Optional[int]
    ) -> tuple[int, int]:
        """Calculate rows for < or <= operator failure (expected fewer rows)"""
        failed_rows = len_rows
        passed_rows = self._remaining_rows(row_count, failed_rows)
        return max(0, passed_rows), max(0, failed_rows)

    def _calculate_equal_failure(
        self, threshold: int, len_rows: int, row_count: Optional[int]
    ) -> tuple[int, int]:
        """Calculate rows for == operator failure (expected exact count)"""
        if row_count:
            if len_rows > threshold:
                # Only the rows beyond the expected count are failures.
                failed_rows = len_rows - threshold
                passed_rows = row_count - failed_rows
            else:
                failed_rows = row_count - len_rows
                passed_rows = len_rows
        else:
            # Without a table total, report only the distance from the target.
            failed_rows = abs(len_rows - threshold)
            passed_rows = 0

        return max(0, passed_rows), max(0, failed_rows)
